{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "\n",
    "\n",
    "## 1. The gather operation\n",
    "\n",
    "The official PyTorch documentation describes gather as follows:\n",
    "\n",
    "<img src=\"gather_官方文档.PNG\"  width=\"800\" height=\"1200\" align=\"left\" />\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
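    "### Understanding gather\n",
    "\n",
    "`tensor_A.gather(dim, index)`: picks values out of tensor_A along dimension `dim`, at the positions given by `index`.\n",
    "\n",
    "(1) index must have the same number of dimensions as the source tensor_A;\n",
    "\n",
    "(2) note how index selects the values from tensor_A:\n",
    "\n",
    "For a 2-D tensor:\n",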
    "\n",
    "    if dim=0, output[i][j] = tensor_A[index[i][j]][j];\n",
    "    \n",
    "    if dim=1, output[i][j] = tensor_A[i][index[i][j]];"
   ]
  },
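  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The indexing rule above can be sketched in plain Python on nested lists. This is only an illustration of the rule, not how PyTorch implements it; `gather_2d` is a hypothetical helper name, and the small integer inputs are made up for the example.\n",
    "\n",
    "```python\n",
    "def gather_2d(src, dim, index):\n",
    "    # Hypothetical helper: applies the 2-D gather rule to nested lists.\n",
    "    out = []\n",
    "    for i, row in enumerate(index):\n",
    "        out_row = []\n",
    "        for j, k in enumerate(row):\n",
    "            if dim == 0:\n",
    "                # index replaces the row coordinate\n",
    "                out_row.append(src[k][j])\n",
    "            elif dim == 1:\n",
    "                # index replaces the column coordinate\n",
    "                out_row.append(src[i][k])\n",
    "            else:\n",
    "                raise ValueError('dim must be 0 or 1 for a 2-D input')\n",
    "        out.append(out_row)\n",
    "    return out\n",
    "\n",
    "\n",
    "src = [[1, 2, 3], [4, 5, 6]]\n",
    "idx = [[0, 1, 0], [1, 0, 1]]\n",
    "print(gather_2d(src, 0, idx))  # [[1, 5, 3], [4, 2, 6]]\n",
    "print(gather_2d(src, 1, idx))  # [[1, 2, 1], [5, 4, 5]]\n",
    "```\n",
    "\n",
    "In both cases the result has the same shape as the index, because one output element is produced per index entry."
   ]
  },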
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "tensor_A = 10 * torch.randn(3, 5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[  2.8331, -12.4045, -21.8001,  -2.0757,   2.1819],\n",
       "        [ -4.8956,   6.3741,   7.7292,  -2.6811,  13.2136],\n",
       "        [ -5.2171, -16.6636, -11.2940,  -7.7390,  16.0390]])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tensor_A"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[0, 1, 2, 0, 0],\n",
      "        [2, 0, 0, 1, 2]])\n",
      "torch.Size([2, 5])\n"
     ]
    }
   ],
   "source": [
    "index = [\n",
    "    [0, 1, 2, 0, 0], \n",
    "    [2, 0, 0, 1, 2],\n",
    "]\n",
    "index = torch.tensor(index, dtype=torch.long)\n",
    "print(index)\n",
    "print(index.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[  2.8331,   6.3741, -11.2940,  -2.0757,   2.1819],\n",
       "        [ -5.2171, -12.4045, -21.8001,  -2.6811,  16.0390]])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# dim = 0\n",
    "tensor_C = tensor_A.gather(0, index)\n",
    "tensor_C"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "Check one element for `dim=0`: take i=1, j=3, where `index[1][3] = 1`:\n",
    "\n",
    "    tensor_A[index[i][j]][j] = tensor_A[1][3] = -2.6811\n",
    "    \n",
    "    output[i][j] = output[1][3] = -2.6811\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[  2.8331, -12.4045, -21.8001,   2.8331,   2.8331],\n",
       "        [  7.7292,  -4.8956,  -4.8956,   6.3741,   7.7292]])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# dim = 1\n",
    "tensor_D = tensor_A.gather(1, index)\n",
    "tensor_D"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Check one element for `dim=1`: take i=1, j=3, where `index[1][3] = 1`:\n",
    "\n",
    "    tensor_A[i][index[i][j]] = tensor_A[1][1] = 6.3741\n",
    "    \n",
    "    output[i][j] = output[1][3] = 6.3741\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "PyCharm (第七课_代码)",
   "language": "python",
   "name": "pycharm-76f391d8"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}